import time
# Adam optimizers (learning rate 0.0002) for the generator and the critic.
g_optimizer = tf.keras.optimizers.Adam(0.0002)
d_optimizer = tf.keras.optimizers.Adam(0.0002)

# Fixed latent batch, reused after every epoch so that generated samples are
# directly comparable across epochs. Its distribution is selected by mode_z.
if mode_z == 'uniform':
    fixed_z = tf.random.uniform(shape=(batch_size, z_size), minval=-1, maxval=1)
elif mode_z == 'normal':
    # Standard normal N(0, 1) latents, as the mode name says.
    fixed_z = tf.random.normal(shape=(batch_size, z_size))
else:
    # Fail fast here instead of with a NameError on fixed_z during training.
    raise ValueError(
        "Unsupported mode_z: {!r} (expected 'uniform' or 'normal')".format(mode_z))
def create_samples(g_model, input_z):
    """Run the generator on latent vectors and rescale its output for display.

    Args:
        g_model: Generator model; called in inference mode (training=False).
        input_z: Batch of latent vectors fed to the generator.

    Returns:
        Tensor reshaped to (num_samples, *image_size), where num_samples is
        inferred from the generator output, with values mapped by
        (x + 1) / 2. This maps [-1, 1] to [0, 1] provided the generator's
        output lies in [-1, 1] (presumably a tanh output layer -- confirm).
    """
    g_output = g_model(input_z, training=False)
    # -1 lets the batch dimension follow input_z instead of the global
    # batch_size, so any number of latent vectors can be rendered.
    images = tf.reshape(g_output, (-1, *image_size))
    return (images + 1) / 2.0
# One entry per epoch: the list of per-batch loss tuples
# (g_loss, d_loss, d_loss_real, d_loss_fake).
all_losses = []
# One entry per epoch: generated images (NumPy) for the fixed latent batch.
epoch_samples = []
start_time = time.time()

for epoch in range(1, num_epochs + 1):
    epoch_losses = []
    for i, (input_z, input_real) in enumerate(mnist_trainset):
        # Both tapes record the same forward pass; each network's loss is
        # later differentiated only w.r.t. that network's variables.
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            g_output = gen_model(input_z, training=True)

            d_critics_real = disc_model(input_real, training=True)
            d_critics_fake = disc_model(g_output, training=True)

            # Wasserstein losses: the generator tries to raise the critic's
            # score on fakes; the critic raises real scores, lowers fake ones.
            g_loss = -tf.math.reduce_mean(d_critics_fake)
            d_loss_real = -tf.math.reduce_mean(d_critics_real)
            d_loss_fake = tf.math.reduce_mean(d_critics_fake)
            d_loss = d_loss_real + d_loss_fake

            # Gradient penalty evaluated on random interpolations between
            # real and generated samples.
            with tf.GradientTape() as gp_tape:
                # One mixing coefficient per sample; the [B, 1, 1, 1] shape
                # assumes 4-D image batches (TODO confirm layout).
                alpha = tf.random.uniform(
                    shape=[d_critics_real.shape[0], 1, 1, 1], minval=0.0, maxval=1.0)
                interpolated = alpha * input_real + (1 - alpha) * g_output
                gp_tape.watch(interpolated)
                # training=True so the penalty constrains the same critic
                # function (dropout/normalization mode) used for the real and
                # fake scores above.
                d_critics_intp = disc_model(interpolated, training=True)

            # Taken while d_tape is still recording, so the penalty's
            # dependence on the critic weights is differentiable
            # (second-order gradient).
            grads_intp = gp_tape.gradient(d_critics_intp, [interpolated])[0]
            # Tiny epsilon keeps the sqrt gradient finite when a per-sample
            # gradient norm is exactly zero (d sqrt(x)/dx is infinite at 0).
            grads_intp_l2 = tf.sqrt(
                tf.reduce_sum(tf.square(grads_intp), axis=[1, 2, 3]) + 1e-12)
            grad_penalty = tf.reduce_mean(tf.square(grads_intp_l2 - 1.0))
            d_loss = d_loss + lambda_gp * grad_penalty

        # Critic update.
        d_grads = d_tape.gradient(d_loss, disc_model.trainable_variables)
        d_optimizer.apply_gradients(
            grads_and_vars=zip(d_grads, disc_model.trainable_variables))

        # Generator update.
        g_grads = g_tape.gradient(g_loss, gen_model.trainable_variables)
        g_optimizer.apply_gradients(
            grads_and_vars=zip(g_grads, gen_model.trainable_variables))

        epoch_losses.append((g_loss.numpy(), d_loss.numpy(),
                             d_loss_real.numpy(), d_loss_fake.numpy()))

    all_losses.append(epoch_losses)
    # Report elapsed minutes and this epoch's mean of each loss component.
    print('에포크 {:03d} | 시간 {:.2f} min | 평균 손실 >> 생성자/판별자 {:6.2f}/{:6.2f} [판별자-진짜: {:6.2f} 판별자-가짜: {:6.2f}]'.format(
        epoch, (time.time() - start_time) / 60,
        *list(np.mean(all_losses[-1], axis=0))))
    epoch_samples.append(create_samples(gen_model, fixed_z).numpy())